# Network graph for a single-output linear model trained on an
# absolute-percentage-error loss: loss_i = |fc1_i / label_i - 1|.

# Input placeholders, bound by name at training time.
data <- mx.symbol.Variable(name = "data")
label <- mx.symbol.Variable(name = "label")
# Pass the label through an explicit identity node named "label".
label <- mx.symbol.identity(label, name = "label")

# One linear output unit, reshaped with shape = 0 (keep the first
# dimension as-is) so it can be combined elementwise with `label`.
fc1 <- mx.symbol.FullyConnected(
  data = data,
  num_hidden = 1,
  name = "fc1"
)
fc1 <- mx.symbol.Reshape(data = fc1, shape = 0, name = "fc1_reshape")

# |fc1 / label - 1| equals |(label - fc1) / label|, the relative error.
# NOTE(review): a label of 0 makes this term infinite; confirm the
# targets contain no zeros.
perc_err <- mx.symbol.abs(fc1 / label - 1, name = "perc_error")

# Wrap the error term as the quantity to minimise.
custom_loss <- mx.symbol.MakeLoss(perc_err, name = "loss")

# Render the graph left-to-right.
graph.viz(custom_loss, direction = "LR", graph.height.px = 160)
# Train the single-output network `custom_loss` with SGD on the
# absolute-percentage-error loss, logging a metric after every epoch.
#
# Why a custom metric: `custom_loss` ends in MakeLoss, so the network's
# forward output is the per-sample loss value |fc1 / label - 1|, not the
# regression prediction. Using `eval.metric = mx.metric.rmse` compared the
# labels against those loss values, which is not a meaningful error measure.
# Any "Train-rmse"/"Validation-rmse" output recorded from that setting is
# affected. Averaging the output instead reports the mean absolute
# percentage error (as a fraction, e.g. 0.12 = 12%) on each data set.
#
# NOTE(review): rows with label == 0 make the loss (and this metric)
# infinite; confirm train_y / test_y contain no zeros.
mape_metric <- mx.metric.custom("mape", function(label, pred) {
  # `pred` is the MakeLoss output, i.e. the per-sample loss. as.array()
  # converts an MXNDArray to an R array and leaves an R array unchanged.
  mean(as.array(pred))
})

model_reg <- mx.model.FeedForward.create(
  symbol = custom_loss,
  X = train_x,
  y = train_y,
  eval.data = list(data = test_x, label = test_y),
  ctx = mx.cpu(),
  num.round = 24,
  array.batch.size = 32,
  optimizer = "sgd",
  learning.rate = 1e-6,
  momentum = 0.9,
  wd = 1e-4,
  epoch.end.callback = mx.callback.log.train.metric(1),
  eval.metric = mape_metric
)
## Warning in mx.model.select.layout.train(X, y): Auto detect layout of input matrix, use rowmajor..
## Start training with 1 devices
## [1] Train-rmse=24.1800087670197
## [1] Validation-rmse=22.6335869930776
## [2] Train-rmse=24.0300878477281
## [2] Validation-rmse=22.6708376486094
## [3] Train-rmse=24.0707992452394
## [3] Validation-rmse=22.7176125992034
## [4] Train-rmse=24.1181655814092
## [4] Validation-rmse=22.7686102279827
## [5] Train-rmse=24.1685198935669
## [5] Validation-rmse=22.8215426911849
## [6] Train-rmse=24.2202787901638
## [6] Validation-rmse=22.8754213353489
## [7] Train-rmse=24.2720386018322
## [7] Validation-rmse=22.9282619101352
## [8] Train-rmse=24.3217555016691
## [8] Validation-rmse=22.9764323638721
## [9] Train-rmse=24.3670866651649
## [9] Validation-rmse=23.0190863297382
## [10] Train-rmse=24.4068048559524
## [10] Validation-rmse=23.0559792101731
## [11] Train-rmse=24.4398535634287
## [11] Validation-rmse=23.0872859712153
## [12] Train-rmse=24.4668716298295
## [12] Validation-rmse=23.1133190146742
## [13] Train-rmse=24.4886991388051
## [13] Validation-rmse=23.134165832689
## [14] Train-rmse=24.5058598416825
## [14] Validation-rmse=23.1511550106171
## [15] Train-rmse=24.5206837980744
## [15] Validation-rmse=23.1661712487192
## [16] Train-rmse=24.533900910326
## [16] Validation-rmse=23.1800148856976
## [17] Train-rmse=24.5461397618477
## [17] Validation-rmse=23.1927502100125
## [18] Train-rmse=24.5574676317429
## [18] Validation-rmse=23.2044778505744
## [19] Train-rmse=24.567663326875
## [19] Validation-rmse=23.2148402080102
## [20] Train-rmse=24.5766600087251
## [20] Validation-rmse=23.2241971898688
## [21] Train-rmse=24.5848173938563
## [21] Validation-rmse=23.2326224105747
## [22] Train-rmse=24.5920757509759
## [22] Validation-rmse=23.2402938270993
## [23] Train-rmse=24.5984723358433
## [23] Validation-rmse=23.247121509478
## [24] Train-rmse=24.6042160592766
## [24] Validation-rmse=23.253354949684